% Auxiliary Function: honest_weak_factors
%   Description:
%       This function implements the main bias-aware method that considers the case when some of the
%       factors are weak
%   Inputs:
%       Y: N * T array of outcomes
%       X: K * N * T array of regressors
%       R: Positive integer, the number of interactive fixed effects used in the estimation
%       Gamma_LS: Preliminary least-squares (LS) estimate of the N * T matrix of interactive fixed
%           effects; computed internally if not provided or empty
%       alpha: Significance level; the confidence intervals have nominal 1 - alpha coverage
%           (default 0.05)
%       clustered_se_flag: Whether to compute clustered standard errors (default 0, i.e. no
%           clustering)
%   Outputs:
%       beta: Point estimates
%       bias: Worst-case bias
%       se: Standard errors
%       LB: Lower bounds of the 1 - alpha bias-aware confidence intervals
%       UB: Upper bounds of the 1 - alpha bias-aware confidence intervals
%       LB_s: Lower bounds of the 1 - alpha standard confidence intervals
%       UB_s: Upper bounds of the 1 - alpha standard confidence intervals
%       A: Matrix of weights

function [beta, bias, se, LB, UB, LB_s, UB_s, A] = honest_weak_factors(Y, X, R, Gamma_LS, alpha, clustered_se_flag)
    % Bias-aware estimation of the K slope coefficients in an N * T panel with
    % R interactive fixed effects, allowing some of the factors to be weak.
    % Returns point estimates, worst-case bias, standard errors, bias-aware and
    % standard 1 - alpha confidence bounds, and the K * N * T weight array A.
    % ========================================
    % Parameter setting
    % ========================================
    % Defaults for missing (or empty) inputs
    if nargin < 4
        Gamma_LS = [];
    end
    if nargin < 5 || isempty(alpha)
        alpha = 0.05;
    end
    if nargin < 6 || isempty(clustered_se_flag)
        clustered_se_flag = 0;
    end
    % Set up dimensions and the tuning parameter b of the MSE criterion
    N = size(Y,1);
    T = size(Y,2);
    K = size(X,1);
    b = 2*R*(sqrt(N)+sqrt(T));
    % Extract the k-th N * T slice of a K * N * T array. reshape (not squeeze)
    % keeps the N * T orientation when N == 1 or T == 1: there squeeze returns a
    % T * 1 column (or leaves a 1 * N row untouched), and the element-wise
    % operations against Y below would silently broadcast to a T * T / N * N result.
    slice_k = @(M, k) reshape(M(k,:,:), N, T);
    % Prepare arrays for outputs
    A = zeros(K,N,T);
    mu_mse = zeros(K,1);
    mse = zeros(K,1);
    % Set up parameters for optimization
    opt_opts2 = struct('MaxIter_step1', 20, ...
        'method1_name', 'PSwarm1', 'step2_multi_fmincon', 0);
    % ========================================
    % Compute the optimal weights
    % ========================================
    for k = 1:K
        X_tmp = slice_k(X, k);
        [V,S,W] = svd(X_tmp,'econ');
        sv = diag(S);
        % The other K - 1 regressors are stacked as columns to be partialled out
        if K == 1
            ZZ = [];
        else
            ZZ = make_ZZ(X([1:k-1,k+1:end],:,:));
        end
        MSE_crit_fn = @(mu) compute_mse(X_tmp, ZZ, mu, b, V, S, W);
        % Search the penalty mu between the smallest and largest singular value
        % of X_k, starting from the min, median and mean singular values
        [mu_mse(k), ~] = hybrid_fmincon2(MSE_crit_fn, [min(sv) median(sv) mean(sv)], min(sv), max(sv), opt_opts2);
        [mse(k), A(k,:,:)] = MSE_crit_fn(mu_mse(k));
    end
    % ========================================
    % Compute the preliminary estimates
    % ========================================
    if isempty(Gamma_LS)
        repMIN = 3;
        repMAX = 10;
        start_delta = zeros(K,1);
        [~,~,lambda,f] = LS_factor(Y, X, R, 'silent', 10^-8, 'm1', start_delta, repMIN, repMAX, 2, 2);
        Gamma_LS = lambda*f';
    end
    Y_tilde_LS = Y - Gamma_LS;
    delta_pre = zeros(K,1);
    Y_diff = Y;
    for k = 1:K
        % Preliminary slope: <A_k, Y - Gamma_LS>_F; then remove X_k * slope from Y
        delta_pre(k) = sum(sum(slice_k(A, k).*Y_tilde_LS));
        Y_diff = Y_diff - slice_k(X, k)*delta_pre(k);
    end
    % ========================================
    % Compute the final estimates
    % ========================================
    % Rank-R approximation of the outcome net of the preliminary slopes
    [V,S,W] = svds(Y_diff,R);
    Gamma_pre = V(:,1:R)*S(1:R,1:R)*W(:,1:R)';
    Y_tilde = Y - Gamma_pre;
    U = Y_diff - Gamma_pre;   % residuals after removing slopes and factors
    % Bias scale: 2R times the largest singular value of the residuals
    C = 2*R*svds(U,1);
    % Two-sided standard normal critical value (loop invariant)
    z = norminv(1-alpha/2);
    % Prepare arrays for outputs
    beta = zeros(K,1);
    bias = zeros(K,1);
    se = zeros(K,1);
    LB = zeros(K,1);
    UB = zeros(K,1);
    LB_s = zeros(K,1);
    UB_s = zeros(K,1);
    % Loop over all K regressors and compute the estimates
    for k = 1:K
        A_k = slice_k(A, k);
        beta(k) = sum(sum(A_k.*Y_tilde));
        % Worst-case bias: C times the largest singular value of A_k
        bias(k) = C*svds(A_k,1);
        if clustered_se_flag
            % Clustered by row of Y: aggregate A_k .* U across columns first
            se(k) = sqrt(sum((sum(A_k.*U,2)).^2));
        else
            se(k) = sqrt(sum(sum((A_k.*U).^2)));
        end
        LB(k) = beta(k) - bias(k) - se(k)*z;
        UB(k) = beta(k) + bias(k) + se(k)*z;
        LB_s(k) = beta(k) - se(k)*z;
        UB_s(k) = beta(k) + se(k)*z;
    end
end


%%
% Auxiliary Function: make_ZZ
%   Description:
%       This function converts a K * N * T array into an (N * T) * K matrix whose k-th column is
%       the k-th N * T slice vectorized row by row
%   Inputs:
%       Z: K * N * T array
%   Outputs:
%       ZZ: (N * T) * K array

function ZZ = make_ZZ(Z)
    % Flatten every K-slice of Z into one column of ZZ. Each slice is read row
    % by row: Z(k,1,1:T), then Z(k,2,1:T), and so on. This ordering matches the
    % reshape(M', [], 1) vectorization used by compute_gamma.
    n_slices = size(Z,1);
    n_rows = size(Z,2);
    n_cols = size(Z,3);
    ZZ = zeros(n_rows*n_cols, n_slices);
    for j = 1:n_slices
        % T * N (conjugate) transpose so that (:) walks the original rows
        slice_t = reshape(Z(j,:,:), n_rows, n_cols)';
        ZZ(:,j) = slice_t(:);
    end
end


%%
% Auxiliary Function: compute_mse
%   Description:
%       This function computes the weights A and the mean squared error defined in Definition 2.2
%       of the paper; the computation follows Appendix B (nuclear-norm-regularized "partialling
%       out" regression)
%   Inputs:
%       X: N * T matrix
%       ZZ: (N * T) * K array
%       mu: Scalar, the penalty on the nuclear norm
%       b: Scalar, the tuning parameter
%       V: V in the singular value decomposition of SVD(X) = VSW'
%       S: S in the singular value decomposition of SVD(X) = VSW'
%       W: W in the singular value decomposition of SVD(X) = VSW'
%   Outputs:
%       A: Weights defined by definition 2.2 in the paper
%       mse: Mean squared error defined by definition 2.2 in the paper

function [mse, A] = compute_mse(X, ZZ, mu, b, V, S, W)
    % Compute the weight matrix A for the N * T regressor X under penalty mu
    % (partialling out the columns of ZZ when ZZ is non-empty), and the MSE
    % criterion value that the caller minimizes over mu.
    % Extract the dimensions from the inputs
    N = size(X,1);
    T = size(X,2);
    % Computation procedure follows appendix B in the paper
    if ~isempty(ZZ)
        % Alternate the least-squares step on ZZ (compute_gamma) with the
        % singular-value soft-thresholding step (compute_X_r_pi). Stop once
        % gamma changes by at most tol. max_iter prevents an endless loop if
        % the iteration stalls above tol.
        tol = 1e-4;
        max_iter = 1000;
        [gamma, X_r_gamma] = compute_gamma(X, ZZ, X);
        delta_gamma = Inf;
        iter = 0;
        while delta_gamma > tol && iter < max_iter
            X_r_pi = compute_X_r_pi(X_r_gamma, mu, X);
            [gamma_u, X_r_gamma] = compute_gamma(X_r_pi, ZZ, X);
            delta_gamma = max(abs(gamma - gamma_u));
            gamma = gamma_u;
            iter = iter + 1;
        end
        if delta_gamma > tol
            warning('compute_mse:noConvergence', ...
                'Partialling-out iteration did not converge within %d iterations (change %g > tol %g).', ...
                max_iter, delta_gamma, tol);
        end
        % Remove the part explained by ZZ (ZZ*gamma is row-major vectorized, as in make_ZZ)
        Omega = X_r_pi - reshape(ZZ*gamma, [T, N])';
    else
        % No other regressors: cap the singular values of X at mu
        Omega = V*diag(min(mu, diag(S)))*W';
    end
    % Following definition 2.2 in the paper
    A = Omega / (sum(sum(X.*Omega))); % Normalize A to have <A,X>_F = 1, also see appendix B.2
    % Largest singular value of A. A is rescaled before svds, presumably so the
    % 1e-3 tolerance is applied to a well-scaled matrix (TODO confirm intent).
    s1a_scaling = N*T / sqrt(max(N,T));
    s1a = svds(s1a_scaling*A, 1, 'largest', 'Tolerance', 1e-3) / s1a_scaling;
    % Criterion: min(N,T)^2 * (b^2 * s1(A)^2 + ||A||_F^2)
    mse = (min(N,T))^2 * (b^2*s1a^2 + sum(sum(A.^2)));
end


%%
% Auxiliary Function: compute_X_r_pi
%   Description:
%       This function computes X after partialling out its potential correlation with the
%       estimation error, applying nuclear-norm regularization (soft-thresholding of the
%       singular values of X_r_gamma by mu)
%   Inputs:
%       X_r_gamma: N * T matrix, X after partialling out the potential correlation with the
%           estimation error gamma
%       mu: Scalar, the penalty on the nuclear norm
%       X: N * T matrix
%   Outputs:
%       X_r_pi: X after partialling out the potential correlation with the estimation error and
%           performing the nuclear norm regularization

function X_r_pi = compute_X_r_pi(X_r_gamma, mu, X)
    % Soft-threshold the singular values of X_r_gamma by mu, floor them at
    % zero, and subtract the resulting low-rank matrix from X.
    [left_vecs, sigma, right_vecs] = svd(X_r_gamma, 'econ');
    shrunk = max(diag(sigma) - mu, 0);
    low_rank = left_vecs * diag(shrunk) * right_vecs';
    X_r_pi = X - low_rank;
end


%%
% Auxiliary Function: compute_gamma
%   Description:
%       This function computes gamma, the error component that can be correlated with X and Z;
%       it also partials out the potential correlation of X with the estimation error gamma
%   Inputs:
%       X_r_pi: N * T matrix, X after partialling out the potential correlation with the estimation
%           error and performing the nuclear norm regularization
%       ZZ: (N * T) * (K - 1) matrix
%       X: N * T matrix
%   Outputs:
%       gamma: Error component that can be correlated with X and Z
%       X_r_gamma: X after partialling out the potential correlation with the estimation error Gamma

function [gamma, X_r_gamma] = compute_gamma(X_r_pi, ZZ, X)
    % Least-squares regression of the row-major vectorized X_r_pi on the
    % columns of ZZ. The fitted values are removed from X, restoring the
    % N * T layout.
    n_rows = size(X,1);
    n_cols = size(X,2);
    % Row-major vectorization (same ordering as make_ZZ)
    X_r_pi_t = X_r_pi';
    target = X_r_pi_t(:);
    gamma = ZZ \ target;
    fitted = ZZ*gamma;
    X_r_gamma = X - reshape(fitted, [n_cols, n_rows])';
end
